#!/usr/bin/python
# coding:utf-8
import keras
from tflearn.layers.core import fully_connected
from keras.datasets import mnist
from keras import backend as K
from keras.layers import Input, Dense
from keras.models import Model

num_classes = 10
img_rows, img_cols = 28, 28


(trainX, trainY), (testX, testY) = mnist.load_data()

# 根据不同的底层设置输入层的格式
if K.image_data_format() == "channels_first":
    trainX = trainX.reshape(trainX.shape[0], 1, img_rows, img_cols)
    testX = testX.reshape(testX.shape[0], 1, img_rows, img_cols)

    input_shape = (1, img_rows, img_cols)  # mnist都是黑白图片, 所以第一维为1
else:
    trainX = trainX.reshape(trainX.shape[0], img_rows, img_cols, 1)
    testX = testX.reshape(testX.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# 将像素转化为0~1间的实数
trainX = trainX.astype("float32")
testX = testX.astype("float32")
trainX /= 255.0
testX /= 255.0

# 将标准答案转化为需要的格式(one-hot)
trainY = keras.utils.to_categorical(trainY, num_classes)
testY = keras.utils.to_categorical(testY, num_classes)

input1 = Input(shape=(784,), name="input1")
input2 = Input(shape=(10,), name="input2")

x = Dense(1, activation="relu")(input1)  # 只有一个隐藏节点的全连接网络
output1 = Dense(10, activation="softmax",name="output1")(x)  # 输出层

y = keras.layers.concatenate([x,input2])  # 将隐藏节点的输出与input2拼接
output2 = Dense(10, activation="softmax", name="output2")(y)

model = Model(inputs=[input1,input2], outputs=[output1,output2])  # 定义模型, 秩序将所有输入输出给出即可

model.compile(
    loss=keras.losses.categorical_crossentropy,  # 指定损失函数
    optimizer=keras.optimizers.SGD(),  # 指定优化方法
    loss_weights=[1, 0.1],  # 为不同的损失指定权重
    metrics=["accuracy"]  # 指定优化目标
)
model.fit(
    [trainX, trainY],[trainY,trainY],
    batch_size=128, epochs=20,
    validation_data=([testX,testY],[testY,testY])
)